import numpy as np
from tqdm import tqdm
import concurrent.futures
from functools import partial
from faster_caching import *
from plot_utils import *

# set error handling
import warnings

def uncached_tokens_to_latency(uncached_tokens):
    """Look up the latency recorded for a given number of uncached tokens.

    The lookup table is read from ``./vllm_latency_experiments_repeats.csv``
    the first time a non-zero value is requested and is then kept on the
    function object, so the file is parsed at most once per process.

    Args:
        uncached_tokens: Number of uncached tokens. ``0`` short-circuits to a
            latency of ``0`` without touching the table.

    Returns:
        ``0`` when ``uncached_tokens`` is ``0``; otherwise
        ``table[uncached_tokens]``.
    """
    if uncached_tokens == 0:
        return 0

    # The table is memoized as a function attribute. The previous version read
    # a local variable before assigning it, which raised UnboundLocalError on
    # every non-zero call.
    table = getattr(uncached_tokens_to_latency, "_table", None)
    if table is None:
        # Local import: this module does not import pandas explicitly.
        import pandas as pd

        table = pd.read_csv("./vllm_latency_experiments_repeats.csv")
        uncached_tokens_to_latency._table = table

    # NOTE(review): on a DataFrame this selects the column labelled
    # `uncached_tokens` -- confirm this matches the CSV layout.
    return table[uncached_tokens]

def run_single_policy(policy_func, a_list, C, xi=None, Q=None, predictor=None, forced=None, threshold=None):
    """Invoke ``policy_func`` with the argument list its name calls for.

    The policy is identified by ``policy_func.__name__``; each recognised
    name receives only the subset of parameters it accepts.

    Raises:
        ValueError: If the function name is not one of the known policies.
    """
    name = policy_func.__name__

    # Tail-optimized variants take the extra xi (and, for LRU, Q/predictor).
    if name == 'tail_optimized_Belady_cache_policy':
        return policy_func(a_list, C, xi, forced)
    if name == 'tail_optimized_LRU_cache_policy':
        return policy_func(a_list, C, xi, Q, predictor=predictor, forced=forced)

    # Plain baselines only need the trace, the capacity and `forced`.
    if name in ('Belady_cache_policy', 'LRU_cache_policy'):
        return policy_func(a_list, C, forced)

    # Threshold LRU additionally takes its threshold.
    if name == 'thre_lru_cache_policy':
        return policy_func(a_list, C, threshold, forced)

    raise ValueError(f"Unknown policy function: {policy_func.__name__}")

def process_cache_capacity(C, a_list, xi, Q, forced, percentiles):
    """Evaluate every configured policy at capacity ``C`` and summarise it.

    Each policy is submitted to a thread pool through ``run_single_policy``.
    A policy that raises is reported on stdout and contributes ``NaN`` for
    every percentile instead of aborting the whole capacity.

    Returns:
        A dict with one ``'<policy name>_<percentile>'`` entry per
        policy/percentile pair (the ``np.percentile`` of that policy's
        output) plus a ``'capacity'`` entry holding ``C``.
    """
    # One entry per policy variant; 'name' becomes the result-key prefix.
    configs = [
        {'func': tail_optimized_Belady_cache_policy, 'predictor': None, 'name': 'belady', 'threshold': None},
        {'func': tail_optimized_LRU_cache_policy, 'predictor': 'None', 'name': 'lru', 'threshold': None},
        {'func': tail_optimized_LRU_cache_policy, 'predictor': 'End', 'name': 'lru_end', 'threshold': None},
        {'func': tail_optimized_LRU_cache_policy, 'predictor': 'Perfect', 'name': 'lru_perfect', 'threshold': None},
        {'func': LRU_cache_policy, 'predictor': None, 'name': 'vanilla_lru', 'threshold': None},
        {'func': thre_lru_cache_policy, 'predictor': None, 'threshold': 1024, 'name': 'thre_lru'},
    ]

    outcomes = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Map each in-flight future back to the policy label it belongs to.
        pending = {
            pool.submit(
                run_single_policy,
                policy_func=cfg['func'],
                a_list=a_list,
                C=C,
                xi=xi,
                Q=Q,
                predictor=cfg['predictor'],
                forced=forced,
                threshold=cfg['threshold'],
            ): cfg['name']
            for cfg in configs
        }

        # Harvest in completion order; a failure is recorded as None.
        for done in concurrent.futures.as_completed(pending):
            label = pending[done]
            try:
                outcomes[label] = done.result()
            except Exception as exc:
                print(f"Policy {label} generated an exception: {exc}")
                outcomes[label] = None

    # Percentile-major ordering, NaN for any policy that failed above.
    summary = {
        f'{label}_{p}': np.nan if values is None else np.percentile(values, p)
        for p in percentiles
        for label, values in outcomes.items()
    }

    # Carry the capacity along so the caller can identify this result.
    summary['capacity'] = C
    return summary

def run_parallel_cache_evaluation(C_values, a_list, xi, Q, forced, percentiles, max_workers=None):
    """Evaluate all cache policies over several capacities in parallel.

    Each capacity in ``C_values`` is handed to ``process_cache_capacity`` in
    a separate worker process (which in turn runs the individual policies
    in threads).

    Every returned list has exactly ``len(C_values)`` entries, in the same
    order as ``C_values``. If processing a capacity raises, the error is
    printed and that capacity's entries are ``NaN``, so the lists remain
    aligned with ``C_values`` for plotting.

    Args:
        C_values: Cache capacities to evaluate.
        a_list: Access trace forwarded to every policy.
        xi: Forwarded to ``process_cache_capacity``.
        Q: Forwarded to ``process_cache_capacity``.
        forced: Forwarded to ``process_cache_capacity``.
        percentiles: Percentiles to report for each policy.
        max_workers: Size of the process pool; ``None`` lets
            ``ProcessPoolExecutor`` pick its default.

    Returns:
        A 6-tuple of dicts ``{percentile: [value per capacity]}`` for, in
        order: ``belady``, ``lru``, ``lru_end``, ``lru_perfect``,
        ``vanilla_lru``, ``thre_lru``.
    """
    # Bind everything except the capacity so the pool only ships C per task.
    process_fn = partial(
        process_cache_capacity,
        a_list=a_list,
        xi=xi,
        Q=Q,
        forced=forced,
        percentiles=percentiles,
    )

    policy_names = ['belady', 'lru', 'lru_end', 'lru_perfect', 'vanilla_lru', 'thre_lru']
    result_keys = [f"{policy}_{p}" for policy in policy_names for p in percentiles]

    C_values = list(C_values)
    n_capacities = len(C_values)

    # Pre-size every series with NaN. Results are written by submission index,
    # so order never depends on completion order and a failed capacity leaves
    # a NaN gap instead of shortening (and misaligning) the lists.
    results = {key: [np.nan] * n_capacities for key in result_keys}

    with tqdm(total=n_capacities * len(result_keys), desc="Testing cache capacities") as progress_bar, \
            concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Key futures by position (not by capacity value) so duplicate
        # capacities in C_values are handled correctly.
        future_to_index = {
            executor.submit(process_fn, C): idx for idx, C in enumerate(C_values)
        }

        for future in concurrent.futures.as_completed(future_to_index):
            idx = future_to_index[future]
            try:
                result = future.result()
            except Exception as exc:
                print(f"Capacity processing generated an exception: {exc}")
            else:
                for key in result_keys:
                    # A missing key stays NaN rather than aborting the run.
                    results[key][idx] = result.get(key, np.nan)
            # Advance for failed capacities too so the bar can reach 100%.
            progress_bar.update(len(result_keys))

    # Regroup the flat "<policy>_<percentile>" series per policy.
    formatted_results = {
        policy: {p: results[f"{policy}_{p}"] for p in percentiles}
        for policy in policy_names
    }

    return (
        formatted_results['belady'],
        formatted_results['lru'],
        formatted_results['lru_end'],
        formatted_results['lru_perfect'],
        formatted_results['vanilla_lru'],
        formatted_results['thre_lru'],
    )

# Script entry point: sweep xi values over a fixed set of cache capacities.
if __name__ == "__main__":
    import os
    import pickle

    # Cache capacities to evaluate.
    C_values = [1000,2000,3000,4000,5000,6000,7000,8000,9000]

    # Load the access trace and keep only conversations whose conv_idx lies
    # in 1..max_conv_idx.
    max_conv_idx = 200
    a_list = pd.DataFrame(load_data("ShareGPT"))
    a_list = a_list[a_list['conv_idx'].isin(range(1, max_conv_idx + 1))]

    forced = 0
    percentiles = [50, 90, 95, 99]
    Q = 100  # held fixed across the sweep
    xi_values = [694, 1498, 2302, 3107, 7934]

    # Keys under which each policy's results are stored, in the same order as
    # the tuple returned by run_parallel_cache_evaluation.
    result_labels = (
        'tail_belady',
        'tail_lru_none',
        'tail_lru_end',
        'tail_lru_perfect',
        'vanilla_lru',
        'thre_lru',
    )

    results_by_xi = {}
    for xi in xi_values:
        print(f"Running for xi={xi}, Q={Q}")
        set_name_run(f"xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}")

        # max_workers=None lets the process pool choose its default size.
        per_policy = run_parallel_cache_evaluation(
            C_values, a_list, xi, Q, forced, percentiles, max_workers=None
        )

        results_dict = dict(zip(result_labels, per_policy))
        results_by_xi[xi] = results_dict

        # Per-xi figure: linear x-axis, logarithmic y-axis.
        _fig, _axes = create_cache_policy_plots(
            C_values,
            *per_policy,
            percentiles,
            title=f"Tail Metrics of Uncached Tokens vs. Cache Capacity (xi={xi})",
            filename=f"tail_metrics_comparison_xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}.png",
            log_scale=(False, True),
        )

        # Persist this run's results next to its other artifacts.
        out_dir = f"./results/xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}"
        pkl_path = f"{out_dir}/parallel_caching.pkl"
        os.makedirs(out_dir, exist_ok=True)
        with open(pkl_path, "wb") as fh:
            pickle.dump(results_dict, fh)
        print(f"Results saved to {pkl_path}")

    # One comparison figure per percentile, across all xi values.
    for p in percentiles:
        create_xi_comparison_plots(
            C_values,
            results_by_xi,
            p,
            filename=f"xi_comparison_p{p}_Q{Q}forced{forced}_maxconvidx{max_conv_idx}.png",
            title=f"P{p} Comparison Across Different xi Values (Q={Q})",
        )